#include <sys/cdefs.h>
#include "esmd_types.h"
#include "listed.hpp"
#include "geometry.hpp"

// Two-body bond term: E = kr * (|r| - r0)^2 with r = x[0] - x[1].
// Accumulates the equal-and-opposite pair force into f[0] / f[1] and
// returns the energy E.
__always_inline real harmonic_bond_param::calc(fref<2> &f, xref<2> &x)
{
  vec<real> delta = (x[0] - x[1]);
  real dist = delta.norm();
  real inv_dist = 1 / dist;
  const auto stretch = dist - r0;
  real energy = kr * stretch * stretch;
  real dedr = 2.0 * kr * stretch;

  // dE/dr along the unit bond vector; subtracted from site 0, added to site 1.
  const auto pull = dedr * delta * inv_dist;
  f[0] -= pull;
  f[1] += pull;
  return energy;
}
// Three-body angle term: E = ktheta * (theta - theta0)^2, where theta is the
// value reported by angle_geometry::rad(). Each site's force contribution is
// dE/dtheta scaled by the geometry's per-site ffactor. Returns E.
__always_inline real harmonic_angle_param::calc(fref<3> &f, xref<3> &x) {
  angle_geometry geom(x);
  real rad = geom.rad();
  const auto deviation = rad - theta0;
  real energy = ktheta * deviation * deviation;
  real slope = 2 * ktheta * deviation;
  for (int site = 0; site < 3; ++site) {
    f[site] += slope * geom.ffactor[site];
  }
  return energy;
}
// Angle term plus a distance term between sites 0 and 2:
//   E = ktheta * (theta - theta0)^2 + kub * (|x0 - x2| - r0ub)^2.
// Forces from both parts are accumulated into f; the summed energy is returned.
__always_inline real urey_bradly_param::calc(fref<3> &f, xref<3> &x){
  // Angular part.
  angle_geometry geom(x);
  real rad = geom.rad();
  const auto deviation = rad - theta0;
  real e_angle = ktheta * deviation * deviation;
  real slope_angle = 2 * ktheta * deviation;
  for (int site = 0; site < 3; ++site) {
    f[site] += slope_angle * geom.ffactor[site];
  }

  // Distance part between the two outer sites.
  vec<real> d02 = x[0] - x[2];
  real dist02 = d02.norm();
  real inv_dist02 = 1 / dist02;
  const auto stretch = dist02 - r0ub;
  real e_dist = kub * stretch * stretch;
  real slope_dist = 2 * kub * stretch;
  const auto pull = slope_dist * d02 * inv_dist02;
  f[0] -= pull;
  f[2] += pull;
  return e_angle + e_dist;
}
// Four-body periodic torsion: E = kphi * (1 + cos(np * phi - phi0)).
// On __sw_slave__ builds the pcos/psin routines replace cos/sin.
// Adds dE/dphi times the geometry's per-site ffactor to f and returns E.
__always_inline real cosine_torsion_param::calc(fref<4> &f, xref<4> &x) {
  dihedral_geometry dih(x);
  const auto phase = np * dih.rad() - phi0;
  #ifndef __sw_slave__
  real energy = kphi * (1 + cos(phase));
  real slope = -np * kphi * sin(phase);
  #else
  real energy = kphi * (1 + pcos(phase));
  real slope = -np * kphi * psin(phase);
  #endif
  for (int site = 0; site < 4; ++site) {
    f[site] += slope * dih.ffactor[site];
  }
  return energy;
}

// Four-body torsion expressed as a degree-5 polynomial in cos(phi):
//   E = c0 + c1*c + c2*c^2 + c3*c^3 + c4*c^4 + c5*c^5,  c = cos(phi).
// Both E and dE/dc are evaluated with Horner's scheme; dE/dphi = -sin(phi) * dE/dc
// is then distributed over the four sites through the geometry's ffactor.
__always_inline real ryckaert_bellmans_param::calc(fref<4> &f, xref<4> &x)  {
  dihedral_geometry dih(x);
  real c = dih.cosphi;

  // Energy polynomial, highest power first.
  real energy = c5 * c + c4;
  energy = energy * c + c3;
  energy = energy * c + c2;
  energy = energy * c + c1;
  energy = energy * c + c0;

  // Derivative of the polynomial with respect to cos(phi).
  real dpoly = c5 * c * 5.0 + c4 * 4.0;
  dpoly = dpoly * c + c3 * 3.0;
  dpoly = dpoly * c + c2 * 2.0;
  dpoly = dpoly * c + c1;

  real slope = -dih.sinphi * dpoly;

  for (int site = 0; site < 4; ++site) {
    f[site] += slope * dih.ffactor[site];
  }
  return energy;
}
// Four-body torsion expressed as a degree-6 polynomial in cos(phi):
//   E = c0 + c1*c + ... + c6*c^6,  c = cos(phi).
// Both E and dE/dc are evaluated with Horner's scheme; dE/dphi = -sin(phi) * dE/dc
// is then distributed over the four sites through the geometry's ffactor.
__always_inline real extended_ryckaert_bellmans_param::calc(fref<4> &f, xref<4> &x) {
  dihedral_geometry dih(x);
  real c = dih.cosphi;

  // Energy polynomial, highest power first.
  real energy = c6 * c + c5;
  energy = energy * c + c4;
  energy = energy * c + c3;
  energy = energy * c + c2;
  energy = energy * c + c1;
  energy = energy * c + c0;

  // Derivative of the polynomial with respect to cos(phi).
  real dpoly = c6 * c * 6.0 + c5 * 5.0;
  dpoly = dpoly * c + c4 * 4.0;
  dpoly = dpoly * c + c3 * 3.0;
  dpoly = dpoly * c + c2 * 2.0;
  dpoly = dpoly * c + c1;

  real slope = -dih.sinphi * dpoly;

  for (int site = 0; site < 4; ++site) {
    f[site] += slope * dih.ffactor[site];
  }
  return energy;
}
// Four-body improper term: E = kpsi * (phi - psi0)^2, where phi is the value
// reported by dihedral_geometry::rad(). Adds dE/dphi times the geometry's
// per-site ffactor to f and returns E. With IMPR_TEST defined, the
// intermediate values are printed for debugging.
__always_inline real harmonic_improper_param::calc(fref<4> &f, xref<4> &x) {
  dihedral_geometry dih(x);
  real rad = dih.rad();
  const auto deviation = rad - psi0;

  real energy = kpsi * deviation * deviation;
  real slope = 2 * kpsi * deviation;
  #ifdef IMPR_TEST
  printf("%f %f %f %f %f\n", rad, energy, dih.sinphi, dih.cosphi, slope);
  #endif
  for (int site = 0; site < 4; ++site) {
    f[site] += slope * dih.ffactor[site];
  }
  return energy;
}
// Folds a periodic cosine torsion kphi * (1 + cos(np * phi - phi0)) into this
// polynomial-in-cos(phi) parameter set.
//
// Only phi0 == 0 and phi0 == pi (within 1e-6) are representable, because then
// the term reduces to kphi +/- kphi * cos(np * phi), and cos(np * phi) is the
// Chebyshev polynomial T_np(cos(phi)):
//   T1 = x                      T4 = 8x^4 - 8x^2 + 1
//   T2 = 2x^2 - 1               T5 = 16x^5 - 20x^3 + 5x
//   T3 = 4x^3 - 3x              T6 = 32x^6 - 48x^4 + 18x^2 - 1
//
// For any other phi0 a message is printed and the coefficients are left
// untouched. For a periodicity outside 1..6 a message is printed; the constant
// kphi has already been added to c0 at that point (unchanged legacy behaviour).
// Returns *this.
__always_inline extended_ryckaert_bellmans_param &extended_ryckaert_bellmans_param::operator+=(const cosine_torsion_param &p) {
  double factor;
  if (fabs(p.phi0 - M_PI) < 1e-6) {
    // cos(n*phi - pi) == -cos(n*phi)
    factor = -p.kphi;
  } else if (fabs(p.phi0) < 1e-6) {
    // fabs, not abs: an unqualified abs may resolve to the integer overload
    // and truncate the phase before the comparison.
    factor = p.kphi;
  } else {
    // print_type(p);
    printf("Not supported dihedral parameter for eRB.\n");
    // Bail out instead of accumulating an uninitialized factor below.
    return *this;
  }
  // Constant part of kphi * (1 +/- cos(np * phi)).
  c0 += p.kphi;
  switch (p.np) {
  case 1:
    c1 += 1.0 * factor;
    break;
  case 2:
    c0 += -1.0 * factor;
    c2 += 2.0 * factor;
    break;
  case 3:
    c1 += -3.0 * factor;
    c3 += 4.0 * factor;
    break;
  case 4:
    c0 += 1.0 * factor;
    c2 += -8.0 * factor;
    c4 += 8.0 * factor;
    break;
  case 5:
    c1 += 5.0 * factor;
    c3 += -20.0 * factor;
    c5 += 16.0 * factor;
    break;
  case 6:
    c0 += -1.0 * factor;
    c2 += 18.0 * factor;
    c4 += -48.0 * factor;
    c6 += 32.0 * factor;
    break;
  default:
    printf("periodicity is too large for extended rb potential!\n");
  }
  return *this;
}
#ifndef __sw_slave__
// Splits the first `cnt` entries in place by whether their owner id is found
// in tag_map: entries that are found are compacted to the front of `entries`,
// the rest are written backwards from index Capacity - 1. Afterwards `cnt` is
// the number kept and `nexport` the number placed in the tail.
// NOTE(review): a tail write lands on a not-yet-visited entry once
// cnt + (number exported) exceeds Capacity; no check for that is visible
// here — confirm callers keep enough headroom.
template <typename ParamType, typename IdType, int Capacity>
template<typename TagMapType>
void listed_force_cell_ents<ParamType, IdType, Capacity>::export_entries(const TagMapType &tag_map) {
  int keep_end = 0;
  int export_begin = Capacity;
  for (int src = 0; src < cnt; ++src) {
    const int dst = tag_map.contains(entries[src].id[ParamType::owner])
                        ? keep_end++
                        : --export_begin;
    entries[dst] = entries[src];
  }
  cnt = keep_end;
  nexport = Capacity - export_begin;
}
#endif
// Scans the export tail of `from` (its last from.nexport slots) and appends
// every entry whose owner id is found in tag_map to this list. After the scan,
// prints both counters and asserts if this list's kept entries plus its own
// export tail no longer fit into Capacity.
template <typename ParamType, typename IdType, int Capacity>
template<typename TagMapType>
void listed_force_cell_ents<ParamType, IdType, Capacity>::import_entries(const TagMapType &tag_map, listed_force_cell_ents<ParamType, IdType, Capacity> &from){
  for (int src = Capacity - from.nexport; src < Capacity; ++src) {
    if (!tag_map.contains(from.entries[src].id[ParamType::owner])) {
      continue;
    }
    entries[cnt] = from.entries[src];
    ++cnt;
  }
  const bool overflow = cnt + nexport > Capacity;
  if (overflow) {
    printf("%d %d\n", cnt, nexport);
    assert(cnt + nexport <= Capacity);
  }
}
// Appends the export count followed by the export tail (the last `nexport`
// slots of `entries`) to `buffer`, then asserts that every exported entry has
// pid < 1000.
template <typename ParamType, typename IdType, int Capacity>
void listed_force_cell_ents<ParamType, IdType, Capacity>::pack_exported (pack_buffer &buffer){
  const int tail_begin = Capacity - nexport;
  buffer.append(nexport);
  buffer.append(&entries[tail_begin], nexport);
  for (int slot = tail_begin; slot < Capacity; ++slot) {
    assert(entries[slot].pid < 1000);
  }
}

// Reads the export count from `buffer` into `nexport`, then reads that many
// entries into the last `nexport` slots of `entries`, and asserts that every
// entry read has pid < 1000.
template <typename ParamType, typename IdType, int Capacity>
void listed_force_cell_ents<ParamType, IdType, Capacity>::unpack_exported(unpack_buffer &buffer) {
  buffer.extract(nexport);
  // The tail position depends on the count that was just read.
  const int tail_begin = Capacity - nexport;
  buffer.extract(&entries[tail_begin], nexport);
  for (int slot = tail_begin; slot < Capacity; ++slot) {
    assert(entries[slot].pid < 1000);
  }
}

// Evaluates every stored entry: builds position and force views over the
// sites named by the entry's ids (one per index in Is...), calls the matching
// ParamType::calc selected by the entry's pid, and then adds each element of
// the force view back into f at the corresponding id.
// Returns the sum of the energies reported by ParamType::calc.
// NOTE(review): whether fref aliases f or holds private accumulators is not
// visible here — confirm against its definition.
template <typename ParamType, typename IdType, int Capacity>
template <int... Is>
real listed_force_cell_ents<ParamType, IdType, Capacity>::calc(vec<real> *f, vec<real> *x, ParamType *params, utils::iseq<int, Is...> seq){
  real total = 0;
  for (int n = 0; n < cnt; ++n) {
    auto &ent = entries[n];
    xref<ParamType::nbodies> pos(x[ent.id[Is]]...);
    fref<ParamType::nbodies> frc(f[ent.id[Is]]...);
    total += params[ent.pid].calc(frc, pos);

    utils::expand{
      ((f[ent.id[Is]] += frc[Is]), 0)...
    };
  }
  return total;
}
// Convenience overload: forwards to the index-sequence overload using body_seq
// and returns its energy sum.
template <typename ParamType, typename IdType, int Capacity>
real listed_force_cell_ents<ParamType, IdType, Capacity>::calc(vec<real> *f, vec<real> *x, ParamType *params) {
  const real total = this->calc(f, x, params, body_seq);
  return total;
}

// Visits every id of every stored entry and inserts into `guests` each tag
// that is contained in neither `locals` nor `guests`.
template <typename ParamType, typename IdType, int Capacity>
template<typename GuestMapType, typename LocalMapType>
void listed_force_cell_ents<ParamType, IdType, Capacity>::find_guests(GuestMapType &guests, LocalMapType &locals){
  for (int n = 0; n < cnt; ++n) {
    for (int body = 0; body < ParamType::nbodies; ++body) {
      tagint tag = entries[n].id[body];
      // Already known locally or already recorded as a guest: nothing to do.
      if (locals.contains(tag) || guests.contains(tag)) {
        continue;
      }
      guests.insert(tag);
    }
  }
}
#ifndef __sw_slave__
// Rebuilds by_id from by_tag: for each by_tag entry, every id is translated
// through tag_map and stored, together with the unchanged pid, in the by_id
// entry at the same index. If a by_tag entry has pid > 1000, its first two
// ids and its index are printed and the process exits with status 1.
// Finally by_id.cnt is set to by_tag.cnt.
template <typename ParamType, int Capacity>
template <typename TagMapType, int... Is>
void listed_force_cell<ParamType, Capacity>::relink(const TagMapType &tag_map, utils::iseq<int, Is...> seq) {
  for (int n = 0; n < by_tag.cnt; ++n) {
    auto &src = by_tag.entries[n];
    auto &dst = by_id.entries[n];
    std::tie(dst.id[Is]..., dst.pid) = std::tie(tag_map[src.id[Is]]..., src.pid);
    if (src.pid > 1000) {
      printf("%d %d %d\n", src.id[0], src.id[1], n);
      exit(1);
    }
  }
  by_id.cnt = by_tag.cnt;
}
#endif
// Convenience overload: forwards to the index-sequence overload with a
// sequence of ParamType::nbodies indices.
template <typename ParamType, int Capacity>
template <typename TagMapType>
void listed_force_cell<ParamType, Capacity>::relink(const TagMapType &tag_map) {
  using body_indices = utils::make_iseq<int, ParamType::nbodies>;
  relink(tag_map, body_indices{});
}
#ifndef __sw_slave__
// Evaluates the entries held in by_id (via its index-sequence overload with
// body_seq) and returns the energy it reports.
template <typename ParamType, int Capacity>
real listed_force_cell<ParamType, Capacity>::calc(vec<real> *f, vec<real> *x, ParamType *params) {
  const real total = by_id.calc(f, x, params, body_seq);
  return total;
}
#endif

// Allocates `cells` through esmd::allocate, requesting grid.nall.vol()
// elements under the label "listed force grid".
template <typename ParamType, int Capacity>
void listed_force_grid<ParamType, Capacity>::allocate(cellgrid_t &grid) {
  using cell_type = listed_force_cell<ParamType, Capacity>;
  cells = esmd::allocate<cell_type>(grid.nall.vol(), "listed force grid");
}
// Copies the parameter table pointer and count from `listed`, resets the
// by_tag counters of every cell visited by FOREACH_CELL, and then appends each
// topology record whose owner id is found in tag_cell_map to the by_tag list
// of the cell that the map gives for that id.
// NOTE(review): no capacity check is visible on the append — confirm the
// per-cell entry count cannot exceed the storage of by_tag.entries.
template<typename ParamType, int Capacity>
void listed_force_grid<ParamType, Capacity>::distribute(esmd::htab<tag_key, vec<int>> &tag_cell_map, listed_force<ParamType> &listed, cellgrid_t &grid) {
  param = listed.param;
  nparam = listed.nparam;

  // Start from empty per-cell lists.
  FOREACH_CELL(&grid, cx, cy, cz, gcell) {
    auto &tags = cells[get_offset_xyz<true>(grid, cx, cy, cz)].by_tag;
    tags.cnt = 0;
    tags.nexport = 0;
  }

  for (int i = 0; i < listed.topo.size(); i++) {
    auto &top = listed.topo[i];
    if (!tag_cell_map.contains(top.id[ParamType::owner])) {
      continue;
    }
    vec<int> loc = tag_cell_map[top.id[ParamType::owner]];
    auto &tags = cells[get_offset_xyz<true>(grid, loc.x, loc.y, loc.z)].by_tag;
    tags.entries[tags.cnt] = top;
    ++tags.cnt;
  }
}

template<typename TagMapType>
void topology_grids::do_export(int offset, TagMapType &tagmap){
  static constexpr auto subseq = utils::make_iseq<int, TOP_END>{};
  do_export(offset, tagmap, subseq);
}

// For each index in Is... whose present[] flag is set, calls export_entries on
// the by_tag list of the cell at `offset` in that topology's grid.
template<typename TagMapType, int... Is>
void topology_grids::do_export(int offset, TagMapType &tagmap, const utils::iseq<int, Is...> &seq){
  utils::expand{
    ((present[Is] ? get<Is>().cells[offset].by_tag.export_entries(tagmap)
                  : void()),
     0)...
  };
}
template<typename TagMapType>
void topology_grids::relink(int offset, TagMapType &tagmap){
  static constexpr auto subseq = utils::make_iseq<int, TOP_END>{};
  relink(offset, tagmap, subseq);
}
// For each index in Is... whose present[] flag is set, calls relink on the
// cell at `offset` in that topology's grid.
template<typename TagMapType, int... Is>
void topology_grids::relink(int offset, TagMapType &tagmap, const utils::iseq<int, Is...> &seq){
  utils::expand{
    ((present[Is] ? get<Is>().cells[offset].relink(tagmap)
                  : void()),
     0)...
  };
}
// For each index in Is... whose present[] flag is set, imports into the
// by_tag list of cell `self` from the by_tag list of cell `from` in that
// topology's grid.
template<typename TagMapType, int...Is>
void topology_grids::do_import(int self, int from, TagMapType &tagmap, const utils::iseq<int, Is...> &seq){
  utils::expand{
    ((present[Is] ? get<Is>().cells[self].by_tag.import_entries(tagmap, get<Is>().cells[from].by_tag)
                  : void()),
     0)...
  };
}
template<typename TagMapType>
void topology_grids::do_import(int self, int from, TagMapType &tagmap){
  static constexpr auto subseq = utils::make_iseq<int, TOP_END>{};
  do_import(self, from, tagmap, subseq);
}

#ifndef __sw_slave__
// For topology index I: imports entries into the by_tag list of cell `self`
// from the by_tag list of each cell visited by FOREACH_NEARNEIGHBOR around
// (cx, cy, cz), skipping the zero offset, and then, when I < TOP_GUEST_END,
// invokes find_guests on the cell. Returns immediately if present[I] is false.
// NOTE(review): FOREACH_NEARNEIGHBOR is defined elsewhere; presumably it
// declares dx, dy, dz and jcell for each neighbouring cell — confirm against
// the macro definition.
template <int I, typename TagMapType, typename GuestMapType>
void topology_grids::do_import_neighbors(int self, TagMapType &tagmap, GuestMapType &guestmap, cellgrid_t *grid, int cx, int cy, int cz){
  if (!present[I]) return;

  FOREACH_NEARNEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell){
    // The zero offset is the cell itself; nothing to import from it.
    if (dx == 0 && dy == 0 && dz == 0) continue;
    int joff = get_offset_xyz<true>(grid, cx + dx, cy + dy, cz + dz);
    get<I>().cells[self].by_tag.import_entries(tagmap, get<I>().cells[joff].by_tag);
  }
  // NOTE(review): this calls a three-argument find_guests directly on the
  // cell, whereas topology_grids::find_guests in this file calls the
  // two-argument by_tag.find_guests(guests, locals). Confirm that the
  // three-argument member exists and that this difference is intended.
  if (I < TOP_GUEST_END) {
    get<I>().cells[self].find_guests(self, guestmap, tagmap);
  }
}
#endif
// Runs the single-index do_import_neighbors<I> once for every index in Is...,
// in order, with the same arguments.
template <typename TagMapType, typename GuestMapType, int ...Is>
void topology_grids::do_import_neighbors(int self, TagMapType &tagmap, GuestMapType &guestmap, cellgrid_t *grid, int cx, int cy, int cz, utils::iseq<int, Is...> bseq){
  utils::expand{
    (this->do_import_neighbors<Is>(self, tagmap, guestmap, grid, cx, cy, cz),
     0)...
  };
}
// For each index in Is... whose present[] flag is set, runs find_guests on the
// by_tag list of the cell at `offset` in that topology's grid.
template<typename GuestMapType, typename LocalMapType, int ...Is>
void topology_grids::find_guests(int offset, GuestMapType &guests, LocalMapType &locals, const utils::iseq<int, Is...> &seq) {
  utils::expand{
    ((present[Is] ? get<Is>().cells[offset].by_tag.find_guests(guests, locals)
                  : void()),
     0)...
  };
}
template<typename GuestMapType, typename LocalMapType>
void topology_grids::find_guests(int offset, GuestMapType &guests, LocalMapType &locals) {
  find_guests(offset, guests, locals, utils::make_iseq<int, TOP_GUEST_END>{});
}
